{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*Copyright (c) Microsoft Corporation. All rights reserved.*\n",
    "\n",
    "*Licensed under the MIT License.*\n",
    "\n",
    "# Text Classification of Multi-Language Datasets Using a Transformer Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scrapbook as sb\n",
    "import pandas as pd\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "from tempfile import TemporaryDirectory\n",
    "from utils_nlp.common.timer import Timer\n",
    "from sklearn.metrics import classification_report\n",
    "from utils_nlp.models.transformers.sequence_classification import SequenceClassifier\n",
    "\n",
    "from utils_nlp.dataset import multinli\n",
    "from utils_nlp.dataset import dac\n",
    "from utils_nlp.dataset import bbc_hindi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "\n",
    "In this notebook, we fine-tune and evaluate a pretrained Transformer model based on the BERT architecture on three datasets, each in a different language:\n",
    "\n",
    "- [MultiNLI dataset](https://www.nyu.edu/projects/bowman/multinli/): The Multi-Genre NLI corpus, in English\n",
    "- [DAC dataset](https://data.mendeley.com/datasets/v524p5dhpj/2): DataSet for Arabic Classification corpus, in Arabic\n",
    "- [BBC Hindi dataset](https://github.com/NirantK/hindi2vec/releases/tag/bbc-hindi-v0.1): BBC Hindi News corpus, in Hindi\n",
    "\n",
    "If you want to run through the notebook quickly, set the **`QUICK_RUN`** flag in the cell below to **`True`**; the notebook then runs on a small subset of the data for fewer epochs. You can also choose which of the three datasets (**`MultiNLI`**, **`DAC`**, or **`BBC Hindi`**) to experiment with by setting **`USE_DATASET`**."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Running Time\n",
    "\n",
    "The table below lists reference running times for each dataset.\n",
    "\n",
    "|Dataset|QUICK_RUN|Machine Configurations|Running time|\n",
    "|:------|:---------|:----------------------|:------------|\n",
    "|MultiNLI|True|2 NVIDIA Tesla K80 GPUs, 12GB GPU memory| ~ 8 minutes |\n",
    "|MultiNLI|False|2 NVIDIA Tesla K80 GPUs, 12GB GPU memory| ~ 5.7 hours |\n",
    "|DAC|True|2 NVIDIA Tesla K80 GPUs, 12GB GPU memory| ~ 13 minutes |\n",
    "|DAC|False|2 NVIDIA Tesla K80 GPUs, 12GB GPU memory| ~ 5.6 hours |\n",
    "|BBC Hindi|True|2 NVIDIA Tesla K80 GPUs, 12GB GPU memory| ~ 1 minute |\n",
    "|BBC Hindi|False|2 NVIDIA Tesla K80 GPUs, 12GB GPU memory| ~ 14 minutes |\n",
    "\n",
    "If you run into a CUDA out-of-memory error or the Jupyter kernel keeps dying, try reducing `batch_size` and `max_len` in `CONFIG`; note that this may reduce model performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of epochs.\n",
    "QUICK_RUN = True\n",
    "\n",
    "# the dataset you want to try, valid values are: \"multinli\", \"dac\", \"bbc-hindi\"\n",
    "USE_DATASET = \"dac\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Several pretrained models have been made available by [Hugging Face](https://github.com/huggingface/transformers). For text classification, the following pretrained models are supported."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>model_name</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>bert-base-uncased</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>bert-large-uncased</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>bert-base-cased</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>bert-large-cased</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>bert-base-multilingual-uncased</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>bert-base-multilingual-cased</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>bert-base-chinese</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>bert-base-german-cased</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>bert-large-uncased-whole-word-masking</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>bert-large-cased-whole-word-masking</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>bert-large-uncased-whole-word-masking-finetune...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>bert-large-cased-whole-word-masking-finetuned-...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>bert-base-cased-finetuned-mrpc</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>bert-base-german-dbmdz-cased</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>bert-base-german-dbmdz-uncased</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>bert-base-japanese</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>bert-base-japanese-whole-word-masking</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>bert-base-japanese-char</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>bert-base-japanese-char-whole-word-masking</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>bert-base-finnish-cased-v1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>bert-base-finnish-uncased-v1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>roberta-base</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>roberta-large</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>roberta-large-mnli</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>distilroberta-base</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>roberta-base-openai-detector</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>roberta-large-openai-detector</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>xlnet-base-cased</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>xlnet-large-cased</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>distilbert-base-uncased</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>distilbert-base-uncased-distilled-squad</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>31</th>\n",
       "      <td>distilbert-base-german-cased</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>distilbert-base-multilingual-cased</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>33</th>\n",
       "      <td>albert-base-v1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>34</th>\n",
       "      <td>albert-large-v1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>35</th>\n",
       "      <td>albert-xlarge-v1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>36</th>\n",
       "      <td>albert-xxlarge-v1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>37</th>\n",
       "      <td>albert-base-v2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38</th>\n",
       "      <td>albert-large-v2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>39</th>\n",
       "      <td>albert-xlarge-v2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>40</th>\n",
       "      <td>albert-xxlarge-v2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                           model_name\n",
       "0                                   bert-base-uncased\n",
       "1                                  bert-large-uncased\n",
       "2                                     bert-base-cased\n",
       "3                                    bert-large-cased\n",
       "4                      bert-base-multilingual-uncased\n",
       "5                        bert-base-multilingual-cased\n",
       "6                                   bert-base-chinese\n",
       "7                              bert-base-german-cased\n",
       "8               bert-large-uncased-whole-word-masking\n",
       "9                 bert-large-cased-whole-word-masking\n",
       "10  bert-large-uncased-whole-word-masking-finetune...\n",
       "11  bert-large-cased-whole-word-masking-finetuned-...\n",
       "12                     bert-base-cased-finetuned-mrpc\n",
       "13                       bert-base-german-dbmdz-cased\n",
       "14                     bert-base-german-dbmdz-uncased\n",
       "15                                 bert-base-japanese\n",
       "16              bert-base-japanese-whole-word-masking\n",
       "17                            bert-base-japanese-char\n",
       "18         bert-base-japanese-char-whole-word-masking\n",
       "19                         bert-base-finnish-cased-v1\n",
       "20                       bert-base-finnish-uncased-v1\n",
       "21                                       roberta-base\n",
       "22                                      roberta-large\n",
       "23                                 roberta-large-mnli\n",
       "24                                 distilroberta-base\n",
       "25                       roberta-base-openai-detector\n",
       "26                      roberta-large-openai-detector\n",
       "27                                   xlnet-base-cased\n",
       "28                                  xlnet-large-cased\n",
       "29                            distilbert-base-uncased\n",
       "30            distilbert-base-uncased-distilled-squad\n",
       "31                       distilbert-base-german-cased\n",
       "32                 distilbert-base-multilingual-cased\n",
       "33                                     albert-base-v1\n",
       "34                                    albert-large-v1\n",
       "35                                   albert-xlarge-v1\n",
       "36                                  albert-xxlarge-v1\n",
       "37                                     albert-base-v2\n",
       "38                                    albert-large-v2\n",
       "39                                   albert-xlarge-v2\n",
       "40                                  albert-xxlarge-v2"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.DataFrame({\"model_name\": SequenceClassifier.list_supported_models()})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To demonstrate the multilingual capability of Transformer models, this notebook uses the model **`distilbert-base-multilingual-cased`** by default (see `model_name` in `CONFIG` below)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "CONFIG = {\n",
    "    'local_path': TemporaryDirectory().name,\n",
    "    'test_fraction': 0.2,\n",
    "    'random_seed': 100,\n",
    "    'train_sample_ratio': 1.0,\n",
    "    'test_sample_ratio': 1.0,\n",
    "    'model_name': 'distilbert-base-multilingual-cased',\n",
    "    'to_lower': False,\n",
    "    'cache_dir': TemporaryDirectory().name,\n",
    "    'max_len': 150,\n",
    "    'num_train_epochs': 5,\n",
    "    'num_gpus': 2,\n",
    "    'batch_size': 16,\n",
    "    'verbose': False,\n",
    "    'load_dataset_func': None,\n",
    "    'get_labels_func': None\n",
    "}\n",
    "\n",
    "if QUICK_RUN:\n",
    "    CONFIG['train_sample_ratio'] = 0.2\n",
    "    CONFIG['test_sample_ratio'] = 0.2\n",
    "    CONFIG['num_train_epochs'] = 1\n",
    "\n",
    "torch.manual_seed(CONFIG['random_seed'])\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    CONFIG['batch_size'] = 32\n",
    "    \n",
    "if USE_DATASET == \"multinli\":\n",
    "    CONFIG['to_lower'] = True\n",
    "    CONFIG['load_dataset_func'] = multinli.load_tc_dataset\n",
    "    CONFIG['get_labels_func'] = multinli.get_label_values\n",
    "    \n",
    "    if QUICK_RUN:\n",
    "        CONFIG['train_sample_ratio'] = 0.1\n",
    "        CONFIG['test_sample_ratio'] = 0.1\n",
    "elif USE_DATASET == \"dac\":\n",
    "    CONFIG['load_dataset_func'] = dac.load_tc_dataset\n",
    "    CONFIG['get_labels_func'] = dac.get_label_values\n",
    "elif USE_DATASET == \"bbc-hindi\":\n",
    "    CONFIG['load_dataset_func'] = bbc_hindi.load_tc_dataset\n",
    "    CONFIG['get_labels_func'] = bbc_hindi.get_label_values\n",
    "else:\n",
    "    raise ValueError(\"Supported datasets are: 'multinli', 'dac', and 'bbc-hindi'\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Dataset\n",
    "\n",
    "Based on the dataset you chose, the configuration above selects the matching helper function **`load_tc_dataset`**. The helper downloads the raw data, splits it into training and testing sets (sub-sampling each if its sampling ratio is smaller than 1.0), and preprocesses the text for the transformer model. Everything is done in one function call, which returns PyTorch data loaders for training and testing, the fitted label encoder, and the test labels. You can then use them to fine-tune the model and evaluate its performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 80.1k/80.1k [00:02<00:00, 30.8kKB/s]\n",
      "/media/bleik2/backup/.conda/envs/nlp_gpu/lib/python3.6/site-packages/sklearn/model_selection/_split.py:2179: FutureWarning: From version 0.21, test_size will always complement train_size unless both are specified.\n",
      "  FutureWarning)\n"
     ]
    }
   ],
   "source": [
    "train_dataloader, test_dataloader, label_encoder, test_labels = CONFIG['load_dataset_func'](\n",
    "    local_path=CONFIG['local_path'],\n",
    "    test_fraction=CONFIG['test_fraction'],\n",
    "    random_seed=CONFIG['random_seed'],\n",
    "    train_sample_ratio=CONFIG['train_sample_ratio'],\n",
    "    test_sample_ratio=CONFIG['test_sample_ratio'],\n",
    "    model_name=CONFIG['model_name'],\n",
    "    to_lower=CONFIG['to_lower'],\n",
    "    cache_dir=CONFIG['cache_dir'],\n",
    "    max_len=CONFIG['max_len'],\n",
    "    num_gpus=CONFIG['num_gpus']\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fine Tune\n",
    "\n",
    "There are two steps to fine-tune a transformer model for text classification: 1) instantiate a `SequenceClassifier`, which is a wrapper around the transformer model, and 2) fit the model using the preprocessed training data. The `fit` method of the `SequenceClassifier` class fine-tunes the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/media/bleik2/backup/.conda/envs/nlp_gpu/lib/python3.6/site-packages/torch/nn/parallel/_functions.py:61: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training time : 0.190 hrs\n"
     ]
    }
   ],
   "source": [
    "model = SequenceClassifier(\n",
    "    model_name=CONFIG['model_name'],\n",
    "    num_labels=len(label_encoder.classes_),\n",
    "    cache_dir=CONFIG['cache_dir']\n",
    ")\n",
    "\n",
    "# Fine tune the model using the training dataset\n",
    "with Timer() as t:\n",
    "    model.fit(\n",
    "        train_dataloader=train_dataloader,\n",
    "        num_epochs=CONFIG['num_train_epochs'],\n",
    "        num_gpus=CONFIG['num_gpus'],\n",
    "        verbose=CONFIG['verbose'],\n",
    "        seed=CONFIG['random_seed']\n",
    "    )\n",
    "\n",
    "print(\"Training time : {:.3f} hrs\".format(t.interval / 3600))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate on Testing Dataset\n",
    "\n",
    "The `predict` method of `SequenceClassifier` returns a NumPy ndarray of raw predictions. Each predicted value is a label ID; to get the corresponding label values, call the `get_label_values` function from the dataset package. The sklearn `LabelEncoder` instance returned when loading the dataset holds the mapping between label IDs and label values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prediction time : 0.021 hrs\n"
     ]
    }
   ],
   "source": [
    "with Timer() as t:\n",
    "    preds = model.predict(\n",
    "        test_dataloader=test_dataloader,\n",
    "        num_gpus=CONFIG['num_gpus'],\n",
    "        verbose=CONFIG['verbose']\n",
    "    )\n",
    "\n",
    "print(\"Prediction time : {:.3f} hrs\".format(t.interval / 3600))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we compute the precision, recall, and F1 metrics of the evaluation on the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "     culture       0.93      0.94      0.93       548\n",
      "     diverse       0.94      0.94      0.94       640\n",
      "     economy       0.90      0.88      0.89       570\n",
      "    politics       0.87      0.88      0.88       809\n",
      "      sports       0.99      0.98      0.99      1785\n",
      "\n",
      "   micro avg       0.94      0.94      0.94      4352\n",
      "   macro avg       0.93      0.93      0.93      4352\n",
      "weighted avg       0.94      0.94      0.94      4352\n",
      "\n"
     ]
    }
   ],
   "source": [
    "report = classification_report(\n",
    "    test_labels, \n",
    "    preds,\n",
    "    digits=2,\n",
    "    labels=np.unique(test_labels),\n",
    "    target_names=label_encoder.classes_\n",
    ")\n",
    "\n",
    "print(report)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/scrapbook.scrap.json+json": {
       "data": 0.94,
       "encoder": "json",
       "name": "precision",
       "version": 1
      }
     },
     "metadata": {
      "scrapbook": {
       "data": true,
       "display": false,
       "name": "precision"
      }
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "application/scrapbook.scrap.json+json": {
       "data": 0.94,
       "encoder": "json",
       "name": "recall",
       "version": 1
      }
     },
     "metadata": {
      "scrapbook": {
       "data": true,
       "display": false,
       "name": "recall"
      }
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "application/scrapbook.scrap.json+json": {
       "data": 0.94,
       "encoder": "json",
       "name": "f1",
       "version": 1
      }
     },
     "metadata": {
      "scrapbook": {
       "data": true,
       "display": false,
       "name": "f1"
      }
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# For testing: record the weighted-average precision, recall, and F1 with scrapbook\n",
    "report_splits = report.split('\\n')[-2].split()\n",
    "\n",
    "sb.glue(\"precision\", float(report_splits[2]))\n",
    "sb.glue(\"recall\", float(report_splits[3]))\n",
    "sb.glue(\"f1\", float(report_splits[4]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.6.8 64-bit ('nlp_gpu': conda)",
   "language": "python",
   "name": "python36864bitnlpgpucondaa579511bcea84c65877ff3dca4205921"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
